#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jan 25 21:35:12 2022

@author: matthewroberts
"""

import os
import sys
import numpy as np
import glob

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects
import matplotlib.gridspec as gridspec
import matplotlib.colors as colors
import matplotlib.cm as cm

import cartopy.crs as crs
import cartopy.io.img_tiles as cimgt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cartopy.io.shapereader as shpreader
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter

import pyproj
import pyproj
from pyproj import Geod

import boto3
from boto3.session import Session
import datetime as dt
import wrf
from netCDF4 import Dataset
import pandas as pd

import warnings
# Silence every warning for the whole run (numpy invalid-value warnings from
# off-disk GOES pixels, library deprecation notices, ...).
# NOTE(review): this also hides deprecation warnings from matplotlib/cartopy.
warnings.simplefilter('ignore')

##########################################
# Set up functions
##########################################
def latlon_convert(proj_info,lat_rad_1d,lon_rad_1d):
    """
    Convert GOES ABI fixed-grid scan angles (radians) to latitude/longitude (deg).

    proj_info  : object exposing longitude_of_projection_origin,
                 perspective_point_height, semi_major_axis and semi_minor_axis
                 (the file's ``goes_imager_projection`` variable)
    lat_rad_1d : 1-D scan angles; in the navigation equations this plays the
                 role of the E/W angle ``x`` (despite the name)
    lon_rad_1d : 1-D scan angles; plays the role of the N/S angle ``y``

    Returns (lat, lon) as 2-D arrays built with np.meshgrid's default 'xy'
    indexing, i.e. shape (len(lon_rad_1d), len(lat_rad_1d)). Scan angles that
    miss the Earth give a negative discriminant and therefore NaN.
    """
    # Satellite / ellipsoid constants
    lambda_0 = (proj_info.longitude_of_projection_origin*np.pi)/180.0
    r_eq = proj_info.semi_major_axis
    r_pol = proj_info.semi_minor_axis
    H = proj_info.perspective_point_height+proj_info.semi_major_axis
    axis_ratio = (r_eq*r_eq)/(r_pol*r_pol)

    # 2-D grids of both scan angles and their trig terms
    ang_x, ang_y = np.meshgrid(lat_rad_1d, lon_rad_1d)
    cos_x, sin_x = np.cos(ang_x), np.sin(ang_x)
    cos_y, sin_y = np.cos(ang_y), np.sin(ang_y)

    # Quadratic a*r^2 + b*r + c = 0 for the satellite-to-surface distance
    quad_a = np.power(sin_x,2.0) + (np.power(cos_x,2.0)*(np.power(cos_y,2.0)+(axis_ratio*np.power(sin_y,2.0))))
    quad_b = -2.0*H*cos_x*cos_y
    quad_c = (H**2.0)-(r_eq**2.0)
    dist = (-1.0*quad_b - np.sqrt((quad_b**2)-(4.0*quad_a*quad_c)))/(2.0*quad_a)

    # Satellite-centred Cartesian coordinates of the surface point
    sat_x = dist*cos_x*cos_y
    sat_y = - dist*sin_x
    sat_z = dist*cos_x*sin_y

    # Geodetic latitude / longitude in degrees
    lat = (180.0/np.pi)*(np.arctan(axis_ratio*((sat_z/np.sqrt(((H-sat_x)*(H-sat_x))+(sat_y*sat_y))))))
    lon = (lambda_0 - np.arctan(sat_y/(H-sat_x)))*(180.0/np.pi)
    return lat, lon

def plumeProp(latVar,lonVar,coords,sens):
    """
    Find the index bounding box of grid points inside a lon/lat window.

    latVar/lonVar = 2-D latitude/longitude arrays (plain or masked), same shape
    coords = [lon,lat]  #Centroid of plume area
    sens = [lon,lat]    #Distance from centroid to draw box (deg)

    A point matches when its latitude lies strictly within coords[1] +/- sens[1]
    and its longitude strictly within coords[0] +/- sens[0]. Masked or NaN
    points never match, and a whole row is skipped when its first latitude
    value is masked (fringes of the satellite viewing area).

    Returns (lon1, lon2, lat1, lat2): lon1/lon2 are the smallest/largest ROW
    indices and lat1/lat2 the smallest/largest COLUMN indices of the matching
    points. The bounds are inclusive, so slice with
    ``[lon1:lon2+1, lat1:lat2+1]`` to keep the edge points.

    Raises ValueError if no point falls inside the window.
    """
    lat = np.ma.asarray(latVar)
    lon = np.ma.asarray(lonVar)

    # Vectorised box test (replaces a Python loop over every grid point);
    # comparisons against NaN are simply False.
    with np.errstate(invalid='ignore'):
        inside = ((lat > (coords[1]-sens[1])) & (lat < (coords[1]+sens[1])) &
                  (lon > (coords[0]-sens[0])) & (lon < (coords[0]+sens[0])))
    # Masked points never count
    inside = np.ma.filled(inside, False)

    # Skip rows whose first value is missing (outside the satellite view)
    valid_rows = ~np.ma.getmaskarray(lat)[:, 0]
    inside &= valid_rows[:, np.newaxis]

    rows, cols = np.nonzero(inside)
    if rows.size == 0:
        raise ValueError('plumeProp: no grid points inside the requested box')
    return rows.min(), rows.max(), cols.min(), cols.max()

# Creates shaded relief background on maps
from cartopy.io.img_tiles import GoogleTiles
class ShadedReliefESRI(GoogleTiles):
    """
    ESRI "World Shaded Relief" web tiles, used as the map background.

    Only the tile URL differs from cartopy's GoogleTiles.
    """

    def _image_url(self, tile):
        # tile arrives as (x, y, z); the ArcGIS REST path is ordered z/y/x
        col, row, zoom = tile
        base = 'https://server.arcgisonline.com/ArcGIS/rest/services/'
        return base + 'World_Shaded_Relief/MapServer/tile/{}/{}/{}.jpg'.format(zoom, row, col)
    
def planarWind(uVal,vVal,bearing):
    """
    Project the horizontal wind (uVal, vVal) onto a cross-section direction.

    bearing is the cross-section angle in degrees on a 0-360 plane measured
    from north. The wind angle is np.arctan2(uVal, vVal) in degrees, so both
    angles share the same north-referenced convention. Returns
    wind speed * cos(bearing - wind angle); works on scalars or arrays.
    """
    wind_angle = np.rad2deg((np.arctan2(uVal,vVal)))
    angle_diff = bearing - wind_angle
    speed = np.sqrt(uVal**2+vVal**2)
    return speed*np.cos(np.deg2rad(angle_diff))

def relax_zone_remover(input, sr):
    """
    Drop the last ``sr`` rows and the last ``sr`` columns of an array.

    Used to strip the fire-mesh relaxation zone. Slicing (instead of repeated
    np.delete calls) keeps the mask of masked arrays, and avoids 2*sr full
    copies. Always returns a new array, never a view of ``input``. For
    ``sr <= 0`` the whole array is copied; if ``sr`` exceeds an axis length,
    that axis becomes empty.

    input : array with at least 2 dimensions (ndarray, masked array, or any
            type supporting positional 2-D slicing and ``.copy()``)
    sr    : int, number of trailing rows/columns to remove

    NOTE: the parameter name ``input`` shadows the builtin but is kept for
    caller compatibility.
    """
    n_rows = max(input.shape[0] - sr, 0)
    n_cols = max(input.shape[1] - sr, 0)
    return input[:n_rows, :n_cols].copy()

start_time = dt.datetime.now()
print('\nProgram Started: {0}'.format(start_time))
print('===================')

#########################
# Inputs
#########################

keyword = 'bear_fire' #for naming plots
# keyword = 'caldor_fire' #for naming plots

# '2panel' = GOES vs Control; '4panel' = GOES vs Fuelx1, x4, x8
PlotType = ['2panel', '4panel']

mainpath = '/Users/matthewroberts/Documents/Projects/LEAPHI'
goespath = mainpath+'/data/goes_data'
filepath = mainpath+'/data/'+keyword+'/_CONTROL' #data files location
savepath = mainpath+'/plots/'+keyword+'/flame_depth' #where to save figure
countypath = '/Users/matthewroberts/Documents/Data/shapefiles/countyl010g_shp/countyl010g.shp'
filelist_og = sorted(glob.glob(filepath+'/wrfout_d03*'))

# AWS credentials
ACCESS_KEY='xxxxx'
SECRET_KEY='xxxxx'

# Keep only base names; every fuel experiment directory uses the same names
filelist = [os.path.basename(f) for f in filelist_og]

# testing
# filelist = [filelist[-7]]

filepath = mainpath+'/data/'+keyword
fuels = ['_CONTROL', '_FUELx1', '_FUELx4', '_FUELx8']

# Limit the extent of the map/GOES data to a small longitude/latitude range.
if keyword == 'bear_fire':
    extent = [-121.4, -120.86, 39.5, 39.91]
    # extent = [-121.6, -119.8, 39.2, 40.1]
elif keyword == 'caldor_fire':
    extent = [-120.6, -120.32, 38.52, 38.78]
    # extent = [-120.58, -120.35, 38.58, 38.74]
else:
    raise ValueError('No map extent defined for keyword '+repr(keyword))

# Resolution (m) the WRF fire heat flux is block-averaged to (2 km); it is also
# the pixel size assumed when converting GOES FRP (MW per pixel) to a flux.
coarse_val = 2000.

# GOES settings
satname = 'goes17' #goes16 or goes17
product = 'ABI-L2-FDCC' # Fire product
mode = 'M6' #usually M6 but M3 for older data

# Shared plotting resources, created once rather than per figure/axis
cmap1 = plt.get_cmap('hot', 20)
esri_terrain = ShadedReliefESRI()
COUNTIES = cfeature.ShapelyFeature(list(shpreader.Reader(countypath).geometries()),
                                   ccrs.PlateCarree())

TITLE_FMT = '%m-%d-%Y %H:%M'


#########################
# Helpers
#########################
def _load_fire_heat_flux(path):
    """
    Read one wrfout file and block-average its total fire heat flux.

    Returns (tstamp, dt_obj, coarse_flat, coarse_flon, coarse_f_hfx): the
    file's first time (string and datetime), the fire-mesh lat/lon sampled
    every ``coarseness`` points, and the block-averaged heat flux in kW/m2
    (non-positive averages masked, capped at 50, values < 0.001 set to 0).
    """
    # Context manager so every wrfout file is closed once its data is read
    # (wrf.getvar returns in-memory arrays).
    with Dataset(path, mode='r') as wrf_file:
        time = wrf.extract_times(wrf_file, wrf.ALL_TIMES)
        tstamp = str(pd.Timestamp(time[0]))
        dt_obj = dt.datetime.strptime(tstamp, "%Y-%m-%d %H:%M:%S")
        print('\nOpening '+tstamp+'...', end="", flush=True)

        # Fire grid variables
        fgrnhfx = wrf.getvar(wrf_file, "FGRNHFX", timeidx=-1) #Heat flux ground fire
        fcanhfx = wrf.getvar(wrf_file, "FCANHFX", timeidx=-1) #Heat flux crown fire
        # loading lat/lons of fire mesh
        flat = wrf.getvar(wrf_file, 'FXLAT', timeidx=-1)
        flon = wrf.getvar(wrf_file, 'FXLONG', timeidx=-1)

        # subgrid ratio (fire-mesh points per atmospheric grid point)
        sr = int(wrf_file.dimensions['west_east_subgrid'].size)/int(wrf_file.dimensions['west_east_stag'].size)
        dx = int(wrf_file.DX)/sr

    f_hfx = fgrnhfx+fcanhfx #total fire heat flux (W/m2)
    f_hfx = np.ma.masked_array(f_hfx, f_hfx <= 1)

    # removing the relaxation zones of the fire mesh
    f_hfx = relax_zone_remover(f_hfx, int(sr))
    flon = relax_zone_remover(flon, int(sr))
    flat = relax_zone_remover(flat, int(sr))

    print('Regridding data...')
    # Number of fire-mesh cells per coarse (coarse_val m) cell along each axis
    coarseness = int(coarse_val/dx)

    # Zero-pad up to the next multiple of coarseness so the grid tiles evenly
    shape = np.array(f_hfx.shape, dtype=int)
    new_shape = coarseness * np.ceil(shape / coarseness).astype(int)
    zp_f_hfx = np.zeros(new_shape)
    zp_f_hfx[:shape[0], :shape[1]] = f_hfx

    # Mean of each coarseness x coarseness tile
    temp = zp_f_hfx.reshape((new_shape[0] // coarseness, coarseness,
                             new_shape[1] // coarseness, coarseness))
    coarse_f_hfx = np.mean(temp, axis=(1,3))
    coarse_f_hfx = np.ma.masked_less_equal(coarse_f_hfx, 0)
    coarse_f_hfx = coarse_f_hfx/1000. # W/m2 to kW/m2
    # Cap at 50 kW/m2 and zero out negligible values
    coarse_f_hfx[coarse_f_hfx > 50.] = 50
    coarse_f_hfx[coarse_f_hfx < .001] = 0

    # Lat/lon of the first fire-mesh point of each tile
    coarse_flat = flat[::coarseness,::coarseness]
    coarse_flon = flon[::coarseness,::coarseness]
    return tstamp, dt_obj, coarse_flat, coarse_flon, coarse_f_hfx


def _fetch_goes_frp(dt_obj):
    """
    Download the GOES fire product closest to dt_obj and crop it around extent.

    Returns (fireX, fireY, newFRP, goestime): pixel lon/lat, FRP converted to
    kW/m2 (NaN where there is no fire), and the chosen file's LastModified
    time. On any failure (no files, network error, nothing in the crop box)
    the error is printed and empty arrays with goestime=None are returned,
    so the WRF panels can still be drawn.
    """
    try:
        session = Session(aws_access_key_id=ACCESS_KEY,
                          aws_secret_access_key=SECRET_KEY)
        s3 = session.client('s3')
        bucket = 'noaa-'+satname

        # If top of the hour look in the previous hour for the :59 file
        search_time = dt_obj-dt.timedelta(hours=1) if dt_obj.minute == 0 else dt_obj
        prefix = (product + '/' + str(search_time.year) + '/' +
                  search_time.strftime('%j').zfill(3) + '/' +
                  str(search_time.hour).zfill(2) + '/OR_' + product + '-' + mode)
        goesfilelist = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)['Contents']

        # Pick the file whose LastModified (upload) time is closest to the WRF time.
        # NOTE(review): upload time lags the scan time; the scan start in the
        # file name would be a more precise match.
        goesfiles = [obj['LastModified'].replace(tzinfo=None) for obj in goesfilelist]
        goestime = min(goesfiles, key=lambda d: abs(d - dt_obj))
        goesidx = goesfiles.index(goestime)

        print('\nDownloading file from AWS...', end='', flush=True)
        filename = goespath+'/satfile.nc'
        s3.download_file(bucket, goesfilelist[goesidx]['Key'], filename)

        with Dataset(filename, 'r') as data:
            # Scan midpoint time (seconds since 2000-01-01 12:00)
            midpoint = float(data.variables['t'][:])
            scan_mid = dt.datetime(2000,1,1,12) + dt.timedelta(seconds=midpoint)
            print(dt.datetime.strftime(scan_mid,'%Y%m%d_%H%M'))

            # Pixel lat (y) / lon (x) from the fixed-grid scan angles
            y, x = latlon_convert(data.variables['goes_imager_projection'],
                                  data.variables['x'][:], data.variables['y'][:])
            fireFRP = data.variables['Power'][:]
            fireFlags = data.variables['Mask'][:]

        # Cloud-contaminated fire pixels (flags 12, 32) -> mean FRP.
        # NOTE(review): np.ma.mean averages every unmasked pixel, and the
        # second assignment already sees the first one's replacements.
        fireFRP[fireFlags == 12] = np.ma.mean(fireFRP)
        fireFRP[fireFlags == 32] = np.ma.mean(fireFRP)
        # Saturated fire pixels (flags 11, 31) -> max FRP
        fireFRP[fireFlags == 11] = fireFRP.max()
        fireFRP[fireFlags == 31] = fireFRP.max()

        print('\nCropping domain...')
        centroid = [((extent[1]-extent[0])/2.)+extent[0],((extent[3]-extent[2])/2.)+extent[2]]
        row0, row1, col0, col1 = plumeProp(y,x,centroid,[.4,.4])
        # plumeProp bounds are inclusive, hence the +1
        crop = (slice(row0, row1+1), slice(col0, col1+1))
        fireX = x[crop]
        fireY = y[crop]
        fireFRP = fireFRP[crop]

        # No FRP (<= 0, e.g. -9) becomes NaN so it is not shaded
        newFRP = np.ma.masked_less_equal(fireFRP, 0).filled(np.nan)
        print(np.nanmax(newFRP))
        # Convert FRP to a flux comparable with the WRF heat flux: MW -> W,
        # divide by the assumed pixel area (coarse_val**2), then W/m2 -> kW/m2.
        # NOTE(review): the extra factor of 10 is not explained in the
        # original code (presumably an assumed ~10% radiative fraction); confirm.
        newFRP = newFRP*1000000.
        newFRP = newFRP/(coarse_val**2.)
        newFRP = newFRP*10./1000.
        print(np.nanmax(newFRP))
    except Exception as err:
        # Best effort: missing GOES data must not stop the WRF plots
        print('\nGOES data unavailable ({0}); plotting WRF panels only.'.format(err))
        return np.array([]), np.array([]), np.array([]), None
    return fireX, fireY, newFRP, goestime


def _style_map(ax, show_lat_labels):
    """Apply the shared ticks, land shading, county lines and border to a map axis."""
    ax.set_xticks(np.linspace(extent[0]-.03,extent[1]+.03,5)[1:-1],crs=ccrs.PlateCarree()) # set longitude indicators
    ax.set_yticks(np.linspace(extent[2]-.03,extent[3]+.03,5)[1:-1],crs=ccrs.PlateCarree()) # set latitude indicators
    ax.xaxis.set_major_formatter(LongitudeFormatter(number_format='0.2f',degree_symbol=u'\N{DEGREE SIGN}',dateline_direction_label=True))
    ax.xaxis.set_tick_params(width=2,length=5,direction='in',labelsize=11, rotation=20, bottom=True, top=True, zorder=99)
    if show_lat_labels:
        ax.yaxis.set_major_formatter(LatitudeFormatter(number_format='0.2f',degree_symbol=u'\N{DEGREE SIGN}'))
        ax.yaxis.set_tick_params(width=2,length=5,direction='in',labelsize=11, rotation=20, left=True, right=True, zorder=99)
    else:
        # Inner panels keep the tick marks but hide the labels
        ax.yaxis.set_tick_params(width=2,length=5,direction='in',labelsize=0, labelcolor='white', rotation=20, left=True, right=True, zorder=99)

    # Shade land and draw county lines
    ax.add_feature(cfeature.LAND, facecolor='k', alpha=.3, zorder=10)
    ax.add_feature(COUNTIES, facecolor='none', edgecolor='k', linewidth=1, alpha=.4, zorder=13)

    # Map border: GeoAxes.outline_patch was replaced by spines['geo'] (cartopy >= 0.18)
    outline = ax.spines['geo'] if 'geo' in ax.spines else ax.outline_patch
    outline.set_linewidth(2)
    outline.set_zorder(99)


def _draw_panel(ax, lons, lats, values, title, show_lat_labels):
    """
    Draw one shaded-relief map panel with ``values`` shaded on a 0-10 kW/m2 scale.

    Returns the pcolormesh artist, or None when ``values`` is empty.
    """
    ax.set_extent(extent, crs=ccrs.PlateCarree())
    ax.add_image(esri_terrain, 12) # tile zoom level 12
    mesh = None
    if np.size(values) > 0:
        mesh = ax.pcolormesh(lons, lats, values, vmin=0, vmax=10,
                             cmap=cmap1, transform=ccrs.PlateCarree(), zorder=11)
    _style_map(ax, show_lat_labels)
    ax.set_title(title, fontsize=12, fontweight='bold')
    return mesh


def _wrf_fire(hflux):
    """Return a copy of ``hflux`` with values <= 0 set to NaN so they are not shaded."""
    fire = np.ma.array(hflux, copy=True)
    fire[fire <= 0] = np.nan
    return fire


def _finish_figure(fig, mesh, cax, outname):
    """Add the heat-flux colorbar, save the figure under savepath and close it."""
    cbar = fig.colorbar(mesh, cax=cax, orientation='vertical')
    cbar.set_ticks([0,2,4,6,8,10])
    cbar.set_label(r'$\mathregular{H_{s} [kW m^{-2}]}$', fontsize=8, fontweight='heavy')
    cbar.ax.xaxis.set_label_position('bottom') #if horizontal
    plt.savefig(savepath+'/'+outname,bbox_inches='tight',dpi=260)
    plt.close('all')
    print('Saved '+outname)


def _plot_4panel(fireX, fireY, newFRP, goes_title, plot_flon, plot_flat, plot_hflux, wrf_title, outname):
    """GOES panel followed by the Fuelx1, x4 and x8 runs (CONTROL skipped)."""
    fig = plt.figure(1,figsize=(10,5))
    spec = gridspec.GridSpec(nrows=40, ncols=125, figure=fig)
    ax1 = fig.add_subplot(spec[4:37,2:30], projection=ccrs.PlateCarree()) #rows,cols
    axVals = [fig.add_subplot(spec[4:37,32:60], projection=ccrs.PlateCarree()),
              fig.add_subplot(spec[4:37,62:90], projection=ccrs.PlateCarree()),
              fig.add_subplot(spec[4:37,92:120], projection=ccrs.PlateCarree())]
    cax = fig.add_subplot(spec[14:26,122:124])

    _draw_panel(ax1, fireX, fireY, newFRP, goes_title, True)
    flux_plot = None
    for ax, lons, lats, hflux in zip(axVals, plot_flon[1:], plot_flat[1:], plot_hflux[1:]):
        print(hflux.max())
        flux_plot = _draw_panel(ax, lons, lats, _wrf_fire(hflux), wrf_title, False)
    _finish_figure(fig, flux_plot, cax, outname)


def _plot_2panel(fireX, fireY, newFRP, goes_title, plot_flon, plot_flat, plot_hflux, wrf_title, outname):
    """GOES panel next to the CONTROL run."""
    fig = plt.figure(1,figsize=(10,5))
    spec = gridspec.GridSpec(nrows=40, ncols=80, figure=fig)
    ax1 = fig.add_subplot(spec[7:33,2:25], projection=ccrs.PlateCarree()) #rows,cols
    ax2 = fig.add_subplot(spec[7:33,32:55], projection=ccrs.PlateCarree()) #rows,cols
    cax = fig.add_subplot(spec[14:26,57:59])

    _draw_panel(ax1, fireX, fireY, newFRP, goes_title, True)
    print(np.nanmax(plot_hflux[0]))
    flux_plot = _draw_panel(ax2, plot_flon[0], plot_flat[0], _wrf_fire(plot_hflux[0]), wrf_title, False)
    _finish_figure(fig, flux_plot, cax, outname)


#########################
# Main loop
#########################
# Each output time is read and GOES is fetched once, then every plot type is drawn.
for wrfname in filelist:
    plot_flat, plot_flon, plot_hflux = [], [], []
    for ff in fuels:
        tstamp, dt_obj, coarse_flat, coarse_flon, coarse_f_hfx = _load_fire_heat_flux(filepath+'/'+ff+'/'+wrfname)
        plot_flat.append(coarse_flat)
        plot_flon.append(coarse_flon)
        plot_hflux.append(coarse_f_hfx)

    print('\nFetching GOES data...')
    fireX, fireY, newFRP, goestime = _fetch_goes_frp(dt_obj)

    print('\nPlotting...')
    wrf_title = dt.datetime.strftime(dt_obj, TITLE_FMT)
    if goestime is not None:
        goes_title = dt.datetime.strftime(goestime, TITLE_FMT)
    else:
        goes_title = 'GOES data unavailable'
    savetime = dt.datetime.strftime(dt_obj, "%Y%m%d_%H%M")

    for plot_type in PlotType:
        # NOTE: as before, the fuel suffix in the name is always the last entry of fuels
        outname = keyword+'_GOEScomp_'+plot_type+fuels[-1]+'_'+savetime+'.png'
        if plot_type == '4panel':
            _plot_4panel(fireX, fireY, newFRP, goes_title,
                         plot_flon, plot_flat, plot_hflux, wrf_title, outname)
        elif plot_type == '2panel':
            _plot_2panel(fireX, fireY, newFRP, goes_title,
                         plot_flon, plot_flat, plot_hflux, wrf_title, outname)

end_time = dt.datetime.now()
print('===================')
print('Total elapsed time: {0}\n'.format(end_time-start_time))